{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.\n"
     ]
    }
   ],
   "source": [
    "from pytorch_pretrained_bert import BertTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the tokenizer from a local vocabulary file; the path is relative to the notebook's working directory.\n",
    "tokenizer = BertTokenizer.from_pretrained('../../../bert-base-chinese/vocab_unk.txt', do_lower_case=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_positions(data_list, map_str):\n",
    "    \"\"\"Find the inclusive token span of ``map_str`` inside ``data_list``.\n",
    "\n",
    "    ``map_str`` is passed through the notebook-level ``tokenizer``, every '#'\n",
    "    is removed from the resulting pieces, and the pieces are re-joined; that\n",
    "    normalised string is what gets matched against the tokens.\n",
    "\n",
    "    Matching is attempted in this order:\n",
    "      1. one token that contains the target (case-insensitive);\n",
    "      2. a run of consecutive tokens beginning at a token the target starts\n",
    "         with, widened by one neighbouring token on either side if needed;\n",
    "      3. two adjacent tokens whose concatenation contains the target.\n",
    "\n",
    "    Args:\n",
    "        data_list: list of token strings to search.\n",
    "        map_str: raw text of the span to locate.\n",
    "\n",
    "    Returns:\n",
    "        ``(start_id, end_id)`` as inclusive indices into ``data_list``, or\n",
    "        ``(-1, -1)`` when no span is found.\n",
    "\n",
    "    Side effects:\n",
    "        Prints the matched text on an exact multi-token match, and a\n",
    "        diagnostic message whenever a partial run has to be widened.\n",
    "    \"\"\"\n",
    "    pieces = tokenizer.tokenize(map_str)\n",
    "    map_str = ''.join(piece.replace('#', '') for piece in pieces)\n",
    "    # An empty target is a substring of every token and would be reported at\n",
    "    # (0, 0); treat it as not found instead.\n",
    "    if not map_str:\n",
    "        return -1, -1\n",
    "    # NOTE(review): '#' is stripped from the target only. If data_list tokens\n",
    "    # still carry '##' word-piece prefixes they will not line up with the\n",
    "    # target -- confirm how callers build data_list.\n",
    "\n",
    "    total = len(data_list)\n",
    "\n",
    "    # Step 1: the target consists of (part of) a single token.\n",
    "    target_lower = map_str.lower()\n",
    "    for idx, word in enumerate(data_list):\n",
    "        if target_lower in word.lower():\n",
    "            return idx, idx\n",
    "\n",
    "    # Step 2: grow a run of tokens from each token the target starts with.\n",
    "    start_id = -1\n",
    "    end_id = -1\n",
    "    for idx, word in enumerate(data_list):\n",
    "        # A span accepted during the previous iteration is returned here.\n",
    "        if start_id != -1 and end_id != -1:\n",
    "            return start_id, end_id\n",
    "        if not map_str.startswith(word):\n",
    "            continue\n",
    "        start_id = end_id = idx\n",
    "        while end_id + 1 < total and data_list[end_id + 1] in map_str:\n",
    "            candidate = ''.join(data_list[start_id:end_id + 2])\n",
    "            if candidate == map_str:\n",
    "                print(candidate)\n",
    "                return start_id, end_id + 1\n",
    "            end_id += 1\n",
    "        find_str = ''.join(data_list[start_id:end_id + 1])\n",
    "        if find_str == map_str:\n",
    "            continue\n",
    "        # The run does not spell the target exactly: try widening it by one\n",
    "        # token to the left, to the right, or both.\n",
    "        has_prev = start_id > 0\n",
    "        has_next = end_id < total - 1\n",
    "        prev_word = data_list[start_id - 1] if has_prev else ''\n",
    "        next_word = data_list[end_id + 1] if has_next else ''\n",
    "        pre_extend = prev_word + find_str\n",
    "        last_extend = find_str + next_word\n",
    "        pre_last_extend = prev_word + find_str + next_word\n",
    "        print('need:{} \\nfind:{} \\n pre:{}\\n last:{} \\n pre_las:{}'.format(\n",
    "            map_str, find_str, pre_extend, last_extend, pre_last_extend))\n",
    "        # Only move a boundary when the neighbouring token really exists, so\n",
    "        # the span never becomes -1 by accident or runs past the last token.\n",
    "        if map_str in pre_extend:\n",
    "            if has_prev:\n",
    "                start_id -= 1\n",
    "        elif map_str in last_extend:\n",
    "            if has_next:\n",
    "                end_id += 1\n",
    "        elif map_str in pre_last_extend:\n",
    "            if has_prev:\n",
    "                start_id -= 1\n",
    "            if has_next:\n",
    "                end_id += 1\n",
    "        else:\n",
    "            start_id = -1\n",
    "            end_id = -1\n",
    "    if start_id != -1 and end_id != -1:\n",
    "        return start_id, end_id\n",
    "\n",
    "    # Step 3: fall back to any two adjacent tokens that together contain the target.\n",
    "    for idx in range(total - 1):\n",
    "        if map_str in data_list[idx] + data_list[idx + 1]:\n",
    "            return idx, idx + 1\n",
    "    return start_id, end_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample sentence for trying out get_positions; word_list holds its tokens.\n",
    "text = '《李烈钧自述》是2011年11月1日人民日报出版社出版的图书，作者是李烈钧'\n",
    "word_list = tokenizer.tokenize(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "人民日报出版社\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(14, 20)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Look up the inclusive token span of the publisher name within word_list.\n",
    "get_positions(word_list, '人民日报出版社')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['机', '械', '工', '业', '出', '版', '社']"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): the saved output of this cell contains characters that do not occur in\n",
    "# `text`, and the execution count jumps from 11 to 33 -- the output appears to come from\n",
    "# a run with a different sentence. Restart the kernel and run all cells to refresh it.\n",
    "word_list[10:17]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
